
Graph Safe Current Scaling Support for GroupedLinear Module/Ops + Fix CUBLAS GGEMM heuristics#3143

Open
vthumbe1503 wants to merge 18 commits into
NVIDIA:main from
vthumbe1503:nvfp4_and_fp8_current_scaling

Conversation

@vthumbe1503

Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 3 commits June 25, 2026 00:40
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Removed details about FP8 current scaling methods.

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as ready for review June 25, 2026 00:57
@vthumbe1503

Collaborator Author

/te-ci pytorch

@greptile-apps

greptile-apps Bot commented Jun 25, 2026

Contributor

Greptile Summary

This PR adds graph-safe FP8 per-tensor current scaling support to GroupedLinear (both the module and ops code paths) on Hopper (CC 9.0) and Blackwell (CC 10.x/11.0). It also fixes a bug in the cuBLAS grouped GEMM heuristic dimension derivation, which computed avg_m/avg_n/avg_k with an incorrect first/last-dim mapping across all three nvte_grouped_gemm* variants.

  • Float8CurrentScaling support: Both _GroupedLinear._is_grouped_tensor_path_supported and GroupedLinear._is_graph_safe_path_supported now fast-return True for Float8CurrentScalingQuantizer inputs before the Blackwell-only guard. The runtime restriction blocking float8_current_scaling with single_grouped_weight=True is lifted. The bug that unconditionally cleared grouped_x.rowwise_data (instead of guarding on columnwise_data is not None) is fixed in both code paths.
  • CUBLAS heuristics fix: Corrects swapped avg_m/avg_n assignments in nvte_grouped_gemm and nvte_grouped_gemm_with_discrete_inputA and flips first/last-dim selection for avg_k across all three functions to match cuBLAS column-major conventions.
  • Tests: Adds fp8_current_scaling parametrized cases to both test files and lowers the device-capability skip threshold from SM100 to SM90 where the feature is now supported.
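The rowwise_data fix described in the first bullet can be illustrated with a minimal sketch. `GroupedBuffers` and both helper functions are hypothetical stand-ins written for this illustration; they are not Transformer Engine classes or the PR's actual code, and only the two attribute names come from the summary above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class GroupedBuffers:
    """Hypothetical stand-in for a quantized grouped tensor (not a TE class)."""
    rowwise_data: Optional[bytes]
    columnwise_data: Optional[bytes]


def release_rowwise_unguarded(x: GroupedBuffers) -> None:
    # Behaviour described as the bug: the row-wise buffer is always dropped.
    x.rowwise_data = None


def release_rowwise_guarded(x: GroupedBuffers) -> None:
    # Behaviour described as the fix: drop the row-wise buffer only when a
    # column-wise buffer exists.
    if x.columnwise_data is not None:
        x.rowwise_data = None


# A tensor that carries only row-wise data keeps it under the guarded version.
rowwise_only = GroupedBuffers(rowwise_data=b"fp8", columnwise_data=None)
release_rowwise_guarded(rowwise_only)

# A tensor with both buffers still has its row-wise copy released.
both = GroupedBuffers(rowwise_data=b"fp8", columnwise_data=b"fp8_t")
release_rowwise_guarded(both)

# The unguarded version leaves a row-wise-only tensor with no data at all.
lost = GroupedBuffers(rowwise_data=b"fp8", columnwise_data=None)
release_rowwise_unguarded(lost)
```

With the guard in place, a tensor that never produced column-wise data is not left empty, which is the failure mode the unconditional clear would cause.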

Confidence Score: 5/5

Safe to merge; the correctness fixes are well-targeted and the cuBLAS heuristic changes affect only algorithm selection, not GEMM output.

The two substantive correctness fixes are straightforward and verified by the new test parametrizations. The cuBLAS avg_m/avg_n/avg_k changes are heuristics fed to the kernel selector and do not affect GEMM output correctness.
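The first/last-dim question behind the heuristic fix comes from a general layout fact: a row-major buffer of shape (first_dim, last_dim) occupies memory exactly like a column-major buffer of shape (last_dim, first_dim). The sketch below checks that fact in plain Python. The helper names are hypothetical, and the code illustrates the convention only; it does not reproduce how cublaslt_grouped_gemm.cu derives avg_m, avg_n, or avg_k.

```python
def row_major_offset(i: int, j: int, rows: int, cols: int) -> int:
    """Linear offset of element (i, j) in a row-major rows x cols buffer."""
    return i * cols + j


def col_major_offset(i: int, j: int, rows: int, cols: int) -> int:
    """Linear offset of element (i, j) in a column-major rows x cols buffer."""
    return i + j * rows


# Row-major shape as a framework would report it.
first_dim, last_dim = 3, 5

# A column-major library sees the same bytes as a (last_dim, first_dim) matrix:
# row-major element (i, j) and column-major element (j, i) share one offset.
same_layout = all(
    row_major_offset(i, j, first_dim, last_dim)
    == col_major_offset(j, i, last_dim, first_dim)
    for i in range(first_dim)
    for j in range(last_dim)
)
```

Because the two views swap which dimension comes first, a size hint read from the wrong one of first_dim or last_dim describes the transposed problem, even though the data itself is untouched.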

transformer_engine/common/gemm/cublaslt_grouped_gemm.cu — specifically the nvte_grouped_gemm_with_discrete_out heuristic derivation, which is used for weight-gradient GEMMs and may now produce less accurate M/N/K estimates than the old code did for that call pattern.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/module/grouped_linear.py | Adds Float8CurrentScalingQuantizer early-return to the grouped-tensor path gate; guards the clearing of rowwise_data so that it runs only when columnwise_data is not None. |
| transformer_engine/pytorch/ops/basic/grouped_linear.py | Adds Float8CurrentScalingQuantizer early-return, adds single_grouped_weight parameter to _is_graph_safe_path_supported to block NVFP4 with grouped weights, removes the float8_current_scaling restriction from single_grouped_weight, and fixes the rowwise_data guard; docstring for _get_grouped_weight_for_gemm not updated. |
| transformer_engine/common/gemm/cublaslt_grouped_gemm.cu | Fixes swapped avg_m/avg_n in nvte_grouped_gemm and nvte_grouped_gemm_with_discrete_inputA, and flips first/last-dim mapping for avg_k across all three GEMM variants to match cuBLAS column-major convention. |
| tests/pytorch/test_grouped_linear.py | Adds fp8_current_scaling parametrized case to two test functions and lowers device-capability skip threshold from SM100 to SM90 for current-scaling recipe. |
| tests/pytorch/test_grouped_mlp.py | Adds fp8_current_scaling and nvfp4_rht to the CUDA-graph-safe test, lowers minimum device capability to SM90, adds cuBLAS version checks, and fixes the single_grouped_weight FP8 skip to only apply to delayed scaling. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[GroupedLinear Forward Call] --> B{_is_graph_safe_path_supported?}
    B -- CC less than 9.0 --> C[Legacy split_quantize path]
    B -- CC 9.0 to 11.0 --> D{with_quantized_compute?}
    D -- No --> E{dtype BF16/FP16?}
    E -- Yes --> F[Grouped Tensor Path BF16/FP16]
    E -- No --> C
    D -- Yes --> G{All Float8CurrentScalingQuantizer?}
    G -- Yes --> H[Grouped Tensor Path FP8 Current Scaling NEW]
    G -- No --> I{CC 10.0 to 11.0?}
    I -- No --> C
    I -- Yes --> J{All MXFP8Quantizer?}
    J -- Yes --> K[Grouped Tensor Path MXFP8]
    J -- No --> L{All NVFP4+RHT AND NOT single_grouped_weight?}
    L -- Yes --> M[Grouped Tensor Path NVFP4]
    L -- No --> C
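
The decision flow in the flowchart can be read as a single selection function. The following is a sketch of that reading only: `Path`, `select_path`, and the string labels for dtypes and quantizer kinds are hypothetical names chosen for illustration, and treating compute capabilities above 11.0 as the legacy path is an assumption, since the flowchart does not cover that case.

```python
from enum import Enum
from typing import Optional, Tuple


class Path(Enum):
    LEGACY_SPLIT_QUANTIZE = "legacy split_quantize path"
    GROUPED_BF16_FP16 = "grouped tensor path (BF16/FP16)"
    GROUPED_FP8_CURRENT_SCALING = "grouped tensor path (FP8 current scaling)"
    GROUPED_MXFP8 = "grouped tensor path (MXFP8)"
    GROUPED_NVFP4 = "grouped tensor path (NVFP4)"


def select_path(
    cc: Tuple[int, int],
    with_quantized_compute: bool,
    dtype: str,
    quantizer_kind: Optional[str],
    single_grouped_weight: bool,
) -> Path:
    """Mirror the flowchart; quantizer_kind means all quantizers share that kind."""
    # Outside CC 9.0 to 11.0 the legacy path is used (above 11.0 is an assumption).
    if not ((9, 0) <= cc <= (11, 0)):
        return Path.LEGACY_SPLIT_QUANTIZE
    if not with_quantized_compute:
        if dtype in ("bf16", "fp16"):
            return Path.GROUPED_BF16_FP16
        return Path.LEGACY_SPLIT_QUANTIZE
    # FP8 current scaling is accepted before the Blackwell-only check.
    if quantizer_kind == "float8_current_scaling":
        return Path.GROUPED_FP8_CURRENT_SCALING
    if cc < (10, 0):
        return Path.LEGACY_SPLIT_QUANTIZE
    if quantizer_kind == "mxfp8":
        return Path.GROUPED_MXFP8
    if quantizer_kind == "nvfp4_rht" and not single_grouped_weight:
        return Path.GROUPED_NVFP4
    return Path.LEGACY_SPLIT_QUANTIZE


hopper_fp8 = select_path((9, 0), True, "bf16", "float8_current_scaling", False)
hopper_mxfp8 = select_path((9, 0), True, "bf16", "mxfp8", False)
blackwell_nvfp4_single = select_path((10, 0), True, "bf16", "nvfp4_rht", True)
```

The ordering is the point: the current-scaling check sits ahead of the CC 10.0 gate, so Hopper reaches the grouped tensor path for that recipe while MXFP8 and NVFP4 remain Blackwell-only.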

Reviews (10): Last reviewed commit: "fix m and n"

Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
vthumbe1503 and others added 5 commits June 26, 2026 17:26
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
… weight being cuda graphable

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…3/TransformerEngine into nvfp4_and_fp8_current_scaling
@vthumbe1503

Collaborator Author

/te-ci pytorch

@vthumbe1503 vthumbe1503 changed the title from "Graph Safe Current Scaling Support for GroupedLinear Module/Ops" to "Graph Safe Current Scaling Support for GroupedLinear Module/Ops + Fix CUBLAS GGEMM heuristics for Wgrad" on Jun 27, 2026
@vthumbe1503

Collaborator Author

/te-ci pytorch

@vthumbe1503 vthumbe1503 requested a review from denera June 29, 2026 16:42

@denera denera left a comment

Collaborator

LGTM except for two minor fixes/clarifications in the GroupedMLP tests.

Comment thread tests/pytorch/test_grouped_mlp.py Outdated
Comment thread tests/pytorch/test_grouped_mlp.py Outdated
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 requested a review from denera June 29, 2026 23:52
@vthumbe1503

Collaborator Author

/te-ci pytorch

Comment thread transformer_engine/common/gemm/cublaslt_grouped_gemm.cu Outdated
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…3/TransformerEngine into nvfp4_and_fp8_current_scaling
@vthumbe1503 vthumbe1503 changed the title from "Graph Safe Current Scaling Support for GroupedLinear Module/Ops + Fix CUBLAS GGEMM heuristics for Wgrad" to "Graph Safe Current Scaling Support for GroupedLinear Module/Ops + Fix CUBLAS GGEMM heuristics" on Jun 30, 2026
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503

Collaborator Author

/te-ci pytorch

@timmoon10 timmoon10 left a comment

Member

LGTM, pending CI and perf checks.

@timmoon10

Member

/te-ci

@denera denera left a comment

Collaborator

LGTM as well!

3 participants